from Utils import *

# Running list of every identity term accepted so far. It is threaded through
# each filter()/filter_exceptions() call below.
all_targets = []

gender_df = pd.read_csv("identity_terms/gender.csv") # TERM,POS removing duplicates w.r.t. TERM 
sexuality_df = pd.read_csv("identity_terms/sexuality.csv") # TERM,POS removing duplicates w.r.t. TERM 
race_df = pd.read_csv("identity_terms/race.csv") # TERM, missing POS but they're adj
countries_df = pd.read_csv("identity_terms/countries.csv") # COUNTRY_ADJ and REGION_ADJ removing duplicates
religion_df = pd.read_csv("identity_terms/religion.csv") # TERM w SEM == person/"", POS removing duplicates w.r.t. TERM 
# Keep only religion terms that denote a person or have no SEM label.
# By default, read_csv parses empty cells as NaN rather than "". A blank SEM must
# therefore be matched with isna(); comparing against "" alone drops those rows.
religion_df = religion_df[
    (religion_df["SEM"] == "person")
    | religion_df["SEM"].isna()
    | (religion_df["SEM"] == "")
]
disability_df = pd.read_csv("identity_terms/disability.csv") # TERM,POS removing duplicates w.r.t. TERM  

# NOTE(review): filter / filter_exceptions come from the Utils star import.
# `filter` shadows the builtin. From usage, each returns (selected terms,
# updated all_targets); the meaning of the boolean flags is defined in Utils.
# Confirm there.
gender_df,all_targets = filter(all_targets,gender_df)
sexuality_df,all_targets = filter(all_targets,sexuality_df)
religion_df,all_targets = filter(all_targets,religion_df)
disability_df,all_targets = filter(all_targets,disability_df,True)
# The same countries table is processed twice with different flags: once for
# the regions, once for the countries.
regions_df,all_targets = filter_exceptions(all_targets,countries_df,False,False)
countries_df,all_targets = filter_exceptions(all_targets,countries_df,False,True)
race_df,all_targets = filter_exceptions(all_targets,race_df,True)
# The helpers presumably return plain lists, so `+` concatenates them.
# TODO: confirm in Utils.
gender_sexuality_df = gender_df + sexuality_df
race_countries_df = race_df + regions_df + countries_df

# Map each bias category to its list of identity terms, dropping the literal
# 'nan' strings (missing cells that were stringified upstream).
data_dict = {
    category: [term for term in terms if term != 'nan']
    for category, terms in (
        ("gender", gender_sexuality_df),
        ("race", race_countries_df),
        ("culture", religion_df),
        ("disabled", disability_df),
    )
}

# Perplexity metric loaded through `load` (from the Utils star import).
perplexity = load("perplexity", module_type="metric") 
# Per-model perplexity scores for the category currently being processed.
PPL = {}

# Final clean-up, in a single pass per category:
# - drop the placeholder string "nan" and genuine missing values
#   (float NaN / None), which the perplexity metric cannot score;
# - capitalise string terms.
# update() rewrites the existing keys in place, so the dict object and its
# key order are unchanged.
data_dict.update({
    key: [
        x.capitalize() if isinstance(x, str) else x
        for x in value
        if not (pd.isna(x) or x == "nan")
    ]
    for key, value in data_dict.items()
})

# Score every category's identity terms with each language model, then write
# one CSV per category. Each CSV has an "identity" column followed by one
# perplexity column per model (columns renamed via LMs_columns_names).
for key, value in data_dict.items():
    for LM in LMs:
        perplexities = perplexity.compute(model_id=LM, predictions=value)
        raw_scores = perplexities['perplexities']
        PPL[LM] = [round(score, 3) for score in raw_scores]
        print(f'\n <----------------------> END of {LM}\n')
    header = ["identity", *PPL]
    rows = list(zip(value, *PPL.values()))
    identities_w_PPL = pd.DataFrame(rows, columns=header).rename(columns=LMs_columns_names)
    file_name = f'{key}-identities-w-PPLs.csv'
    identities_w_PPL.to_csv(f'../Social Bias Probing/{file_name}', index=False)
    print(f'\n\n <----------------------> END of {key}\n\n')